{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "29888eda-2755-4add-af9c-8927afb07db4",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/liming/anaconda3/envs/pytorch/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n"
     ]
    }
   ],
   "source": [
    "from transformers import AutoModel, AutoTokenizer\n",
    "\n",
    "# Load the tokenizer from the 'human_gpt2-v1' checkpoint.\n",
    "tokenizer = AutoTokenizer.from_pretrained('human_gpt2-v1')\n",
    "\n",
    "# Reuse the end-of-sequence token as the padding token so that batches can be padded.\n",
    "# Alternative (not used here): register a dedicated '[PAD]' token through\n",
    "# tokenizer.add_special_tokens({'pad_token': '[PAD]'}).\n",
    "tokenizer.pad_token = tokenizer.eos_token"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "75e3778d-9cbb-48c4-8203-0f71f485a49d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the bare backbone; it is inspected in the following cells to find out which\n",
    "# transformers module implements this checkpoint's architecture.\n",
    "full_model = AutoModel.from_pretrained('human_gpt2-v1')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "be845310-8215-4434-bb92-d44c343f274e",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "transformers.models.gpt2.modeling_gpt2\n"
     ]
    }
   ],
   "source": [
    "# Dotted path of the module that defines the class of the loaded model.\n",
    "gena_module_name = type(full_model).__module__\n",
    "print(gena_module_name)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "bb9ee647-cac2-4475-ac44-2f11aef99735",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "import importlib\n",
    "\n",
    "# The checkpoint resolves to the GPT-2 implementation (see the module name printed above),\n",
    "# so the model classes available in this module are the GPT-2 ones:\n",
    "# - GPT2Model, GPT2LMHeadModel, GPT2DoubleHeadsModel,\n",
    "# - GPT2ForSequenceClassification, GPT2ForTokenClassification,\n",
    "# - GPT2ForQuestionAnswering\n",
    "# check https://huggingface.co/docs/transformers/model_doc/gpt2\n",
    "myclass = importlib.import_module(gena_module_name)  # note: this is the module object, not a class\n",
    "# dir(myclass)  # uncomment to list everything the module exports"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "62ebfcc0-5d46-4c3b-b462-a3842a985e9b",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "transformers.models.gpt2.modeling_gpt2.GPT2ForSequenceClassification"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Pick the sequence-classification class out of the module that implements the model.\n",
    "cls = importlib.import_module(gena_module_name).GPT2ForSequenceClassification\n",
    "cls"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "4523e9b6-3613-41e4-b81d-c139d044e70c",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Some weights of GPT2ForSequenceClassification were not initialized from the model checkpoint at human_gpt2-v1 and are newly initialized: ['score.weight']\n",
      "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
     ]
    }
   ],
   "source": [
    "# Reload the checkpoint through the classification class, configured for 2 labels.\n",
    "model = cls.from_pretrained('human_gpt2-v1', num_labels=2)\n",
    "\n",
    "# Mirror the tokenizer setup: the model config uses the EOS id as the padding id as well.\n",
    "# (If a dedicated pad token were added to the tokenizer instead, the embedding matrix\n",
    "# would have to be resized with model.resize_token_embeddings(len(tokenizer)).)\n",
    "model.config.pad_token_id = model.config.eos_token_id"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "bd024199-16e6-4eb1-b404-c49dc535469b",
   "metadata": {},
   "outputs": [],
   "source": [
    "from datasets import load_dataset\n",
    "\n",
    "# Load the promoter prediction example dataset (~11k samples) and hold out 10% for evaluation.\n",
    "# The fixed seed keeps the train/test split identical across kernel restarts, so the\n",
    "# evaluation numbers of different runs stay comparable.\n",
    "dataset = load_dataset(\"yurakuratov/example_promoters_2k\")['train'].train_test_split(test_size=0.1, seed=42)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "a3a5213d-66f0-4d17-876c-de4426145412",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "DatasetDict({\n",
       "    train: Dataset({\n",
       "        features: ['sequence', 'promoter_presence'],\n",
       "        num_rows: 10656\n",
       "    })\n",
       "    test: Dataset({\n",
       "        features: ['sequence', 'promoter_presence'],\n",
       "        num_rows: 1184\n",
       "    })\n",
       "})"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "36a23870-1f56-466e-b2bb-c5047072b8f0",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'sequence': 'GAGCACATTCGCCTGCGTGCGCACTCACACACACGTTCAAAAAGAGTCCATTCGATTCTGGCAGTAGGCTGAAAACGATCCATATTGACGCGAATAACATCCCAACTGTAAATATGGATTCGTCTACTCGAGGCTGTGTCTCATTTCACCATAGAGGAAAATTATGTCATTTAAAACAGTGATAAAACGCGAAGAGTTTCGATCTTCTGTAAATGGCTCTCACTTTGAATTGTGGGGGAAAAAAATCACTAATTGGGAATAGCCAGAGAAGGGGGAATTTGTTGAGCGCCCAAGGCCAATCTCAAAATATTTTATTTCATGCCAATGTGGGAGAGGGAAGGGAGTGTGGCTTCGCTGTTTGCCGCTTTTGCATTTTCCACATCTGGCCTACAGACCTAATTCTAACTTGCGATCTCGCTTGATGAAATTCGATCTTCGTGGTTTAAATACCTTTCAATTTGGGACGATTTAACCATTGCAGCAGTTCAAAAAACAAACCCTGAGCATTATCTGTAAAACCAATATGGAAAAAAGAAAGAAAACATGCTGCAAACACACAAATCGCAACCATTGTCTTTAACACTCGCATACACACACATACACATAAACACACGGAAAGATGGGGGAGGGGAGCGAGTGGAGAGAGAGAGACACACGGATATAGACCGACGAGAGTCAATGCGAGTCAAGTTAACGCTTTCTCCGCCCCGATGCACAAGTACCCCCCTCCCGTGTTAATTAATTCCAAACAAAGCAAGCCAATCAAGAAATGCATTATTATTTTCAAGCAGAGAGAGACGAGTGCCATGTATTTTTTAATTGATTTATTTATTTGGAGGACATTAATTCTCGGTCTGAGGCTGATTTTCAAACCGTTTGCAAAGCGCGGCTCCATTGTTCGCCAGGCGCGCAGACCCCGCCCCCTGCTTTGTGTGCGGCGCGCGGATTGGCCGACGCGCGCGGGGGCCGCGTCAGCGCCCGCGAGCCCATTCATCTGTGCACTTGGGCGTTGGACCCCGCATCTTATTAGCAACCAGGGAGATTTCTCCATTTTCCTCTTGTCTACAGTGCGGCTACAAATCTGGGATTTTTTTATTACTTCTTTTTTTTTCGAACTACACTTGGGCTCCTTTTTTTGTGCTCGACTTTTCCACCCTTTTTCCCTCCCTCCTGTGCTGCTGCTTTTTGATCTCTTCGACTAAAATTTTTTTATCCGGAGTGTATTTAATCGGTTCTGTTCTGTCCTCTCCACCACCCCCACCCCCCTCCCTCCGGTGTGTGTGCCGCTGCCGCTGTTGCCGCCGCCGCTGCTGCTGCTGCTCGCCCCGTCGTTACACCAACCCGAGGCTCTTTGTTTCCCCTCTTGGATCTGTTGAGTTTCTTTGTTGAAGAAGCCAGCATGGGTGCCCAGTTCTCCAAGACCGCAGCGAAGGGAGAAGCCGCCGCGGAGAGGCCTGGGGAGGCGGCTGTGGCCTCGTCGCCTTCCAAAGCGAACGGACAGGTAAATTCTACCTGTGGTTGTGGTCATTTATTTCGTGTCTTTCTCTCCTTCCCTCCTGGTGGCTCCCAAGGCTCATTCGTGGAGAGGAGTGTGGCTGGATTCTAGTCTGAAAGTTGAATGGAAATGGATAGCCTAAATTGTGGATTTTTCATAGCCGGGGTCCCCCCACACCAAAGGAGTGGTTCCTTAGAATCCTGGAGAATTCTTGAGTAGACGGGAGGAAAGGAGTGATAAATCGTCTAAAATGGCGTGGTTTGGAGATTGGCAGAACGTGACTAGCTAGGTGCCCTCCTTCCCCCGCCCCCAGCCAGGTTTCATTTGTGTTTGGGGGCACAATGGTGTTTGGGTGTCTCTCGGGTTTGTTGTCGAAGCTGGATTGTCTTCGAGATTTCTGAACTGGTGGTGTGATTGTGTGTGGCGGGGGGGGGAGTCGGGAAAGGGGGGCGCAATGGGGATGGTGGTGGAGAAAGGCGAGGA
GGGAAAGGCTGTCTTATTGTAT',\n",
       " 'promoter_presence': 1}"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "dataset['train'][0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "19e22500-c1e0-444e-9b5c-2a1a2af81997",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "# base pairs:  2000\n"
     ]
    }
   ],
   "source": [
    "print('# base pairs: ', len(dataset['train'][0]['sequence']))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "16db0a5d-7aec-4d59-b44b-d22ff43045fd",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tokens:  G AGCAC ATTCGCC TGCG TGCGC ACTC ACACACACG TTCAAAA AGAGTCC ATTCG ATTCTGGC AGTAGGC TG AAAACG ATCC ATATTG ACGC GAA TAAC ATCCC AACTGTAA ATATGG ATTCG TCTAC TCGAGGC TGTGTCTC ATTTCACC ATAG AGGAAAA TTATG TCATT TAAAAC AGTG ATAAAAC GCG AAG AGTTTCG ATCTTC TGTAA ATGGC TCTCAC TTTGAATTG TGGGGG AAAAAA ATCACTAA TTGGGAA TAGCC AGAGAAGG GGG AATTTG TTG AGCGCCC AAGGCC AATCTC AAAATATT TTATT TCATGCC AATGTGGG AGAGGG AAGGGAG TGTGGCTTC GCTG TTTGCC GC TTTTGC ATTTTCC ACATC TGGCC TACAGACC TAATTC TAAC TTGCG ATCTCGC TTG ATGAA ATTCG ATCTTCG TGGTTTAA ATACC TTTC AATTTGGG ACG ATTTAACC ATTGC AGCAGTTC AAAAAAC AAACCC TGAGC ATTATC TGTAA AACC AATATGG AAAAAAG AAAGAAAAC ATGC TGCAA ACACACAA ATCGC AACCATTG TCTTTAAC ACTCGC ATACACACAC ATACAC ATAAACAC ACGG AAAG ATGGGGG AGGGG AGCGAG TGG AGAGAGAG AGACAC ACGG ATAT AGACCG ACG AGAG TCAA TGCG AGTCAAG TTAACGC TTTCTCC GCCCCG ATGC ACAAG TACCCCCC TCCCGTG TTAA TTAATTCC AAACAAAGC AAGCC AATC AAGAAATGC ATTATT ATTTTC AAGC AGAGAGAG ACG AGTGCC ATG TATTTTTT AATTG ATTTATTTATT TGGAGG ACATT AATTCTC GG TCTG AGGCTG ATTTTC AAACCG TTTGC AAAGCGC GGCTCC ATTGTTC GCC AGGC GCGC AGACCCC GCCCCC TGCTTTGTG TGC GGCGC GCGG ATTGGCC G ACGCGC GCGG GGGCC GCG TCAGC GCCCGC GAGCCC ATTCATCTG TGCAC TTGGGCG TTGG ACCCCGC ATCTTATT AGCAACC AGGG AGATTTC TCC ATTTTCC TCTTG TCTAC AGTGC GGC TACAAATC TGGG ATTTTTT TATT ACTTC TTTTTTTT TCGAAC TACAC TTGGGC TCC TTTTTTTG TGC TCGAC TTTTCC ACCCTTTT TCCCTCCC TCCTG TGCTGCTGC TTTTTG ATCTC TTCGAC TAAAATTTT TTTATCC GG AGTGTATT TAATC GG TTCTG TTCTGTCC TCTCC ACCACCCCC ACCCCCC TCCCTCC GG TGTGTGTGCC GCTGCC GCTG TTGCC GCCGCC GC TGCTGC TGCTGCTC GCCCCG TCG TTACACC AACCCG AGGC TCTTTG TTTCCCC TCTTGG ATCTGTTG AGTTTC TTTG TTGAAG AAGCC AGCATGGG TGCCC AGTTCTCC AAGACC GC AGCG AAGGGAG AAGCC GCC GCGG AGAGGCC TGGGGAGGC GGCTG TGGCCTCG TCGCC TTCC AAAGCG AACGG ACAGG TAAATTC TACC TGTGG TTGTGG TCATTTATT TCG TGTCTTTC TCTCC TTCCCTCC TGGTGGC TCCC AAGGC TC ATTCG TGGAG AGGAG TGTGGC TGGATTC TAGTCTG AAAG TTGAATGGAA ATGG ATAGCC TAAATTG TGGATT TTTC ATAGCC GGGG TCCCCCC ACACC AAAGGAG TGGTTCC 
TTAGAA TCC TGGAGAA TTC TTGAG TAGAC GGG AGGAA AGGAGTG ATAA ATCG TCTAA AATGGCG TGG TTTGGAG ATTGGC AGAACG TGAC TAGCTAGG TGCCCTCC TTCCCCC GCCCCC AGCC AGGTTTC ATTTGTG TTTGGGGGC ACAA TGGTG TTTGGG TGTCTC TCGGG TTTGTTG TCGAAGC TGGATTG TCTTC GAG ATTTCTG AAC TGGTGG TGTGATTG TGTGTGGC GGGG GGGGG AGTCGGG AAAGGG GGGCGC AATGGGG ATGG TGGTGG AGAA AGGCG AGGAGGG AAAGGC TGTC TTATTG TAT\n"
     ]
    }
   ],
   "source": [
    "print('tokens: ', ' '.join(tokenizer.tokenize(dataset['train'][0]['sequence'])))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "9876c07a-6332-40cc-aa84-2d26e6a3a5ab",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "# tokens:  350\n"
     ]
    }
   ],
   "source": [
    "print('# tokens: ', len(tokenizer.tokenize(dataset['train'][0]['sequence'])))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "7c4b0442-5fdf-4845-8bf2-4c3b6cc86e84",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Map: 100%|██████████| 10656/10656 [00:01<00:00, 9309.45 examples/s]\n",
      "Map: 100%|██████████| 1184/1184 [00:00<00:00, 10546.08 examples/s]\n"
     ]
    }
   ],
   "source": [
    "def preprocess_labels(example):\n",
    "    \"\"\"Store the `promoter_presence` value of a row under the additional key `label`.\"\"\"\n",
    "    presence = example['promoter_presence']\n",
    "    example['label'] = presence\n",
    "    return example\n",
    "\n",
    "\n",
    "dataset = dataset.map(preprocess_labels)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "46e912af-1aab-4ed2-8007-0f3f970b2255",
   "metadata": {},
   "outputs": [],
   "source": [
    "def preprocess_function(examples, max_length=512):\n",
    "    \"\"\"Tokenize the `sequence` field, keeping at most `max_length` tokens per sequence.\n",
    "\n",
    "    Over-long sequences are simply truncated by the tokenizer (`truncation=True`); for\n",
    "    some tasks a symmetric truncation from the left and the right is more reasonable.\n",
    "\n",
    "    Args:\n",
    "        examples: batch of rows providing a `sequence` entry.\n",
    "        max_length: token limit per sequence. Defaults to 512; pass a smaller value\n",
    "            (e.g. 128) to make experiments faster.\n",
    "\n",
    "    Returns:\n",
    "        Whatever the tokenizer call returns for the batch.\n",
    "    \"\"\"\n",
    "    return tokenizer(examples[\"sequence\"], truncation=True, max_length=max_length)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "0ef16794-71fd-4404-a51e-b55b5cae9aab",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Map: 100%|██████████| 10656/10656 [00:02<00:00, 4268.87 examples/s]\n",
      "Map: 100%|██████████| 1184/1184 [00:00<00:00, 4375.16 examples/s]\n"
     ]
    }
   ],
   "source": [
    "tokenized_dataset = dataset.map(preprocess_function, batched=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "765c199b-9d13-4fd2-ac51-1e5d154766fa",
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import DataCollatorWithPadding\n",
    "\n",
    "# Batch collator built around the tokenizer configured at the top of the notebook;\n",
    "# it is handed to the Trainer below.\n",
    "data_collator = DataCollatorWithPadding(tokenizer=tokenizer)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "afd6f48f-5524-4740-b8ea-54a79fd79dad",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "DatasetDict({\n",
       "    train: Dataset({\n",
       "        features: ['sequence', 'promoter_presence', 'label', 'input_ids', 'attention_mask'],\n",
       "        num_rows: 10656\n",
       "    })\n",
       "    test: Dataset({\n",
       "        features: ['sequence', 'promoter_presence', 'label', 'input_ids', 'attention_mask'],\n",
       "        num_rows: 1184\n",
       "    })\n",
       "})"
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "tokenized_dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "e1b708bb-21e4-416d-b191-79f302ca4e93",
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import TrainingArguments, Trainer\n",
    "import numpy as np\n",
    "\n",
    "\n",
    "def compute_metrics(eval_pred):\n",
    "    \"\"\"Return the accuracy of one evaluation pass as `{'accuracy': value}`.\n",
    "\n",
    "    `eval_pred` unpacks into `(predictions, labels)`; the predicted class of each\n",
    "    row is the argmax of `predictions` along axis 1.\n",
    "    \"\"\"\n",
    "    scores, labels = eval_pred\n",
    "    predicted_classes = np.argmax(scores, axis=1)\n",
    "    n_correct = (predicted_classes == labels).sum()\n",
    "    return {'accuracy': n_correct / len(labels)}\n",
    "\n",
    "\n",
    "# Change these training hyperparameters to achieve better quality.\n",
    "training_args = TrainingArguments(\n",
    "    output_dir=\"test_run\",\n",
    "    # optimisation\n",
    "    optim='adamw_torch',\n",
    "    learning_rate=1e-4,\n",
    "    weight_decay=0.0,\n",
    "    lr_scheduler_type=\"constant_with_warmup\",\n",
    "    warmup_ratio=0.1,\n",
    "    # batch sizes and number of passes over the training split\n",
    "    per_device_train_batch_size=20,\n",
    "    per_device_eval_batch_size=20,\n",
    "    num_train_epochs=5,\n",
    "    # evaluate, checkpoint and log once per epoch\n",
    "    evaluation_strategy=\"epoch\",\n",
    "    save_strategy=\"epoch\",\n",
    "    logging_strategy=\"epoch\",\n",
    "    # reload the best checkpoint once training has finished\n",
    "    load_best_model_at_end=True,\n",
    ")\n",
    "\n",
    "trainer = Trainer(\n",
    "    model=model,\n",
    "    args=training_args,\n",
    "    data_collator=data_collator,\n",
    "    tokenizer=tokenizer,\n",
    "    train_dataset=tokenized_dataset[\"train\"],\n",
    "    eval_dataset=tokenized_dataset[\"test\"],\n",
    "    compute_metrics=compute_metrics,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "082754e6-53fc-4a95-9d1e-887c42975d71",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "You're using a GPT2TokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.\n",
      "/home/liming/anaconda3/envs/pytorch/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
      "  warnings.warn('Was asked to gather along dimension 0, but all '\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "\n",
       "    <div>\n",
       "      \n",
       "      <progress value='1335' max='1335' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
       "      [1335/1335 12:57, Epoch 5/5]\n",
       "    </div>\n",
       "    <table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       " <tr style=\"text-align: left;\">\n",
       "      <th>Epoch</th>\n",
       "      <th>Training Loss</th>\n",
       "      <th>Validation Loss</th>\n",
       "      <th>Accuracy</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <td>1</td>\n",
       "      <td>0.564000</td>\n",
       "      <td>0.418670</td>\n",
       "      <td>0.805743</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>2</td>\n",
       "      <td>0.344300</td>\n",
       "      <td>0.410089</td>\n",
       "      <td>0.804054</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>3</td>\n",
       "      <td>0.203800</td>\n",
       "      <td>0.525304</td>\n",
       "      <td>0.802365</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>4</td>\n",
       "      <td>0.115400</td>\n",
       "      <td>0.622882</td>\n",
       "      <td>0.800676</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>5</td>\n",
       "      <td>0.092000</td>\n",
       "      <td>1.022607</td>\n",
       "      <td>0.733953</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table><p>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/liming/anaconda3/envs/pytorch/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
      "  warnings.warn('Was asked to gather along dimension 0, but all '\n",
      "/home/liming/anaconda3/envs/pytorch/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
      "  warnings.warn('Was asked to gather along dimension 0, but all '\n",
      "/home/liming/anaconda3/envs/pytorch/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
      "  warnings.warn('Was asked to gather along dimension 0, but all '\n",
      "/home/liming/anaconda3/envs/pytorch/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
      "  warnings.warn('Was asked to gather along dimension 0, but all '\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "TrainOutput(global_step=1335, training_loss=0.2639072503936425, metrics={'train_runtime': 781.0108, 'train_samples_per_second': 68.219, 'train_steps_per_second': 1.709, 'total_flos': 1.0180868548460544e+16, 'train_loss': 0.2639072503936425, 'epoch': 5.0})"
      ]
     },
     "execution_count": 19,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "trainer.train()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "55c9d1eb-6fc6-451f-b596-870ccaa81d8d",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
